package com.ferry.jraft.impl;

import com.alipay.remoting.exception.RemotingException;
import com.ferry.jraft.ConsensusModule;
import com.ferry.jraft.Node;
import com.ferry.jraft.StateMachine;
import com.ferry.jraft.concurrent.RaftThreadFactory;
import com.ferry.jraft.concurrent.RaftThreadPoolExecutor;
import com.ferry.jraft.enums.NodeStatusEnum;
import com.ferry.jraft.model.LogEntry;
import com.ferry.jraft.model.Peer;
import com.ferry.jraft.model.PeerGroup;
import com.ferry.jraft.model.dto.*;
import com.ferry.jraft.model.state.LeaderState;
import com.ferry.jraft.model.state.PersistentState;
import com.ferry.jraft.model.state.VolatileState;
import com.ferry.jraft.rpc.RaftRpcClient;
import com.ferry.jraft.rpc.RaftRpcServer;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;

import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.concurrent.TimeUnit.MILLISECONDS;

/**
 * @Author ferry
 * @create 2022/5/21 16:22
 * @description Node的实现类
 */
@Slf4j
@Data
public class NodeImpl implements Node {

    /**
     * 上一次收到心跳的时间（超时则触发选举）
     */
    public volatile long prevElectionTime = 0;

    /**
     * 上一次发送心跳的时间
     */
    public volatile long prevHeartBeatTime = 0;

    /**
     * 选举超时时间
     */
    public volatile long electionTime = 750 + ThreadLocalRandom.current().nextInt(750);

    /**
     * 心跳发送的间隔时间
     */
    public final long heartBeatTick = 250;

    /**
     * 节点的各种状态State
     */
    public LeaderState leaderState = new LeaderState();
    public PersistentState persistentState = new PersistentState();
    public VolatileState volatileState = new VolatileState();

    /**
     * 当前节点的状态
     */
    public volatile NodeStatusEnum nodeStatus = NodeStatusEnum.FOLLOWER;

    /**
     * 节点是否在运行
     */
    public volatile boolean running = false;

    /**
     * 集群信息
     */
    public PeerGroup peerGroup;

    /**
     * 节点的配置信息
     */
    public Peer peerConfig;

    /**
     * 各个模块
     */
    public RaftRpcClient raftRpcClient;
    public RaftRpcServer raftRpcServer;
    public StateMachine stateMachine;
    public ConsensusModule consensusModule;

    private HeartBeatTask heartBeatTask = new HeartBeatTask();
    private ElectionTask electionTask = new ElectionTask();
    private Apply2StateMachineTask apply2StateMachineTask = new Apply2StateMachineTask();

    /**
     * 状态机的应用结果队列
     */
    public LinkedBlockingQueue<Object> applyChannel = new LinkedBlockingQueue<>();

    /**
     * 需要提交的日志队列
     */
    public LinkedBlockingQueue<LogEntry> commitChannel = new LinkedBlockingQueue<>();

    /**
     * 向其他节点真正发起请求的线程池
     */
    public RaftThreadPoolExecutor raftThreadPoolExecutor = new RaftThreadPoolExecutor();

    /**
     * 定时触发上述任务的线程池
     */
    public ScheduledThreadPoolExecutor scheduledThreadPoolExecutor = new ScheduledThreadPoolExecutor(Runtime.getRuntime().availableProcessors(), new RaftThreadFactory());

    @Override
    public void init() throws Throwable {
        log.info("random election timeout:{}", electionTime);

        //初始化模块
        raftRpcClient = new RaftRpcClient();
        raftRpcClient.init();

        raftRpcServer.init();
        consensusModule = new ConsensusImpl(this);

        //执行定时任务：leader向其他节点发送心跳，检测是否有leader向自己发送心跳，尝试再次复制之前失败的日志
        scheduledThreadPoolExecutor.scheduleWithFixedDelay(heartBeatTask, 0, 125, MILLISECONDS);
        scheduledThreadPoolExecutor.scheduleAtFixedRate(electionTask, 2000 + ThreadLocalRandom.current().nextInt(2000), 375, MILLISECONDS);
        raftThreadPoolExecutor.execute(apply2StateMachineTask);

        //初始化任期
        LogEntry logEntry = persistentState.getLogModule().getLast();
        if (logEntry != null) {
            persistentState.setCurrentTerm(logEntry.getTerm());
        }
        running = true;
        log.info("Node {} has been initialized", peerConfig.getId());

        prevElectionTime = System.currentTimeMillis();

    }

    @Override
    public void destroy() throws Throwable {
        raftRpcClient.destroy();
        raftRpcServer.destroy();
        stateMachine.destroy();
        running = false;
        log.info("Node {} has been destroyed", peerConfig.getId());
    }

    @Override
    public void loadStateMachine(StateMachine stateMachine) {
        this.stateMachine = stateMachine;
        log.info("Node {} 's StateMachine is loaded", peerConfig.getId());
    }

    @Override
    public void loadConfig(Peer peerConfig, PeerGroup peerGroup) {
        //节点配置赋值
        this.peerConfig = peerConfig;
        //初始化日志模块
        this.persistentState.setLogModule(new LogModuleImpl(this.peerConfig.getPort()));
        //获取集群配置单例
        this.peerGroup = peerGroup;
        raftRpcServer = new RaftRpcServer(this, peerConfig.getPort());
        log.info("Node {}: Config is loaded", peerConfig.getId());
    }

    @Override
    public AppendEntriesResponse handleAppendEntriesRequest(AppendEntriesRequest appendEntriesRequest) {
        return consensusModule.appendEntries(appendEntriesRequest);
    }

    @Override
    public VoteResponse handleVoteRequest(VoteRequest voteRequest) {
        log.info("Node {}: Receive vote request from Node {}", peerConfig.getId(), voteRequest.getCandidateId());
        return consensusModule.requestVote(voteRequest);
    }

    @Override
    public synchronized ClientResponse handleClientRequest(ClientRequest clientRequest) {
        log.info("Node {}: Receive request from Client", peerConfig.getId());
        if (nodeStatus != NodeStatusEnum.LEADER) {
            log.warn("Node {}: Since self is not leader, the request needs to be redirected", peerConfig.getId());
            try {
                Response redirectResponse = raftRpcClient.send(clientRequest, peerGroup.getLeader().getAddr());
                return (ClientResponse) redirectResponse;
            } catch (RemotingException e) {
                log.error("Node {}: Redirect RPC failed, destination is Node {}", peerConfig.getId(), peerGroup.getLeader().getId());
            }
        } else {
            LogEntry logEntry = LogEntry.builder()
                    .term(persistentState.getCurrentTerm())
                    .value(clientRequest.getParam().toString())
                    .build();

            //写入日志到本地
            persistentState.getLogModule().write(logEntry);
            log.info("Node {}: Write logModule success, logEntry info : {}", peerConfig.getId(), logEntry);

            AtomicInteger appendRes = new AtomicInteger(0);
            List<Future<Boolean>> futureList = new ArrayList<>();

            //分配任务给线程池：追加日志到其他节点
            for (Peer peer : peerGroup.getOtherPeers(peerConfig)) {
                futureList.add(appendEntriesToOtherNodes(peer, logEntry));
            }
            CountDownLatch latch = new CountDownLatch(futureList.size());
            List<Boolean> resList = new ArrayList<>();

            //等待RPC结果返回
            for (Future<Boolean> future : futureList) {
                raftThreadPoolExecutor.execute(() -> {
                    try {
                        resList.add(future.get(500, MILLISECONDS));
                    } catch (Exception e) {
                        resList.add(false);
                    } finally {
                        latch.countDown();
                    }
                });
            }
            try {
                latch.await(750, MILLISECONDS);
            } catch (InterruptedException e) {
                log.error("Node {}: Append entries thread was interrupted", peerConfig.getId());
            }

            for (Boolean aboolean : resList) {
                if (aboolean) {
                    appendRes.incrementAndGet();
                }
            }

            //更新commitIndex
            List<Long> matchIndexList = new ArrayList<>(leaderState.getMatchIndex().values());
            int medium = 0;
            if (matchIndexList.size() >= 2) {
                Collections.sort(matchIndexList);
                medium = matchIndexList.size() / 2;
            }
            Long N = matchIndexList.get(medium);
            AtomicReference<Object> res = new AtomicReference<>();
            if (N > volatileState.getCommitIndex()) {
                LogEntry entry = persistentState.getLogModule().read(N);
                long prevCommitIndex = volatileState.getCommitIndex();
                if (entry != null && entry.getTerm() == persistentState.getCurrentTerm()) {
                    volatileState.setCommitIndex(N);
                    //将旧的commitIndex到新的commitIndex之间的logEntry全部应用
                    for (long i = prevCommitIndex + 1; i <= N; i++) {
                        LogEntry committedEntry = persistentState.logModule.read(i);
                        commitChannel.offer(committedEntry);
                    }
                    Thread clearRes = new Thread(() -> {
                        for (long i = prevCommitIndex + 1; i <= N; i++) {
                            try {
                                res.set(applyChannel.take());
                            } catch (InterruptedException e) {
                                log.error("ClearRes was interrupted");
                            }
                        }
                    }, "clearRes");
                    clearRes.start();
                    try {
                        clearRes.join();
                    } catch (InterruptedException e) {
                        log.error("Interrupted exception");
                    }
                }
            }

            log.info("Node {}: Append entries RPC received consent from {} nodes", peerConfig.getId(), appendRes.get());
            if (appendRes.get() >= (peerGroup.getPeers().size() - 1) / 2) {
                if (logEntry.getIndex() > volatileState.getCommitIndex()) {
                    long prevCommitIndex = volatileState.getCommitIndex();
                    volatileState.setCommitIndex(logEntry.getIndex());
                    //提交到channel
                    for (long i = prevCommitIndex + 1; i <= logEntry.getIndex(); i++) {
                        commitChannel.offer(persistentState.logModule.read(i));
                    }
                    for (long i = prevCommitIndex + 1; i <= logEntry.getIndex(); i++) {
                        try {
                            res.set(applyChannel.take());
                        } catch (InterruptedException e) {
                            log.error("Interrupted exception");
                        }
                    }
                }
                return ClientResponse.builder().success(true).res(res.get()).build();
            } else {
                //失败之后不能重试，否则可能会响应给client false，但commit成功的情况
                return ClientResponse.builder().success(false).build();
            }

        }
        return null;
    }

    /**
     * 向其他节点追加日志
     *
     * @param peer
     * @param logEntry
     * @return
     */
    private Future<Boolean> appendEntriesToOtherNodes(Peer peer, LogEntry logEntry) {
        return raftThreadPoolExecutor.submit(() -> {
            long start = System.currentTimeMillis();
            long end = start;

            //保持与获取异步线程相同的时间，防止出现返回给client false之后依然进行新日志的复制
            while (end - start < 500) {
                if (nodeStatus != NodeStatusEnum.LEADER) {
                    return false;
                }
                AppendEntriesRequest request = AppendEntriesRequest.builder()
                        .term(persistentState.getCurrentTerm())
                        .leaderId(peerGroup.getLeader().getId())
                        .leaderCommit(volatileState.getCommitIndex())
                        .build();
                long nextIndex = leaderState.getNextIndex().get(peer);

                //初始化entries
                LinkedList<LogEntry> entries = new LinkedList<>();
                if (logEntry.getIndex() >= nextIndex) {
                    for (long i = nextIndex; i <= logEntry.getIndex(); i++) {
                        LogEntry entry = persistentState.getLogModule().read(i);
                        if (entry != null) {
                            entries.add(entry);
                        }
                    }
                } else {
                    entries.add(logEntry);
                    log.info("Node {}: appear:logEntry 's index<nextIndex", peerConfig.getId());
                }
                request.setEntries(entries);

                //设置prevLogEntry的信息
                long prevLogIndex = 0L;
                long prevLogTerm = 0L;
                LogEntry firstLogEntry = entries.getFirst();
                if (firstLogEntry != null) {
                    LogEntry prevLogEntry = persistentState.logModule.read(firstLogEntry.getIndex() - 1);
                    if (prevLogEntry != null) {
                        prevLogIndex = prevLogEntry.getIndex();
                        prevLogTerm = prevLogEntry.getTerm();
                    }
                }
                request.setPrevLogIndex(prevLogIndex);
                request.setPrevLogTerm(prevLogTerm);

                request.setCmd(Request.APPENDE_ENTRIES);

                //调用RPC
                try {
                    log.info("Node {}: Call append entries RPC, destination is Node {}", peerConfig.getId(), peer.getId());
                    AppendEntriesResponse response = (AppendEntriesResponse) getRaftRpcClient().send(request, peer.getAddr());
                    if (response == null) {
                        return false;
                    }
                    if (response.isSuccess()) {
                        //如果追加日志成功，更新对应的matchIndex和nextIndex
                        leaderState.getMatchIndex().put(peer, logEntry.getIndex());
                        leaderState.getNextIndex().put(peer, logEntry.getIndex() + 1);
                        return true;
                    } else {

                        if (response.getTerm() > persistentState.getCurrentTerm()) {
                            log.warn("Node {}: There comes a bigger term from Node {}. Self node will become follower. Max term:{},my term:{} ", peerConfig.getId(), peer.getId(), response.getTerm(), persistentState.getCurrentTerm());
                            nodeStatus = NodeStatusEnum.FOLLOWER;
                            persistentState.setCurrentTerm(response.getTerm());
                            return false;
                        } else {
                            //递减nextIndex进行重试
                            if (nextIndex == 0L) {
                                nextIndex = 1L;
                            }
                            leaderState.getNextIndex().put(peer, nextIndex - 1);
                            log.warn("Node {}: Node {} ‘s matchIndex didn't match, need to retry", peerConfig.getId(), peer.getId());
                        }
                    }
                    end = System.currentTimeMillis();
                } catch (RemotingException e) {
                    log.error("Node{}: Append entries RPC failed, destination is Node {}", peerConfig.getId(), peer.getId());
                    return false;
                }
            }
            return false;
        });
    }

    /**
     * 当一定时间内没有收到来自leader的心跳，就触发该任务
     */
    class ElectionTask implements Runnable {

        @Override
        public void run() {
            //如果当前节点为leader，则不触发超时选举
            if (nodeStatus == NodeStatusEnum.LEADER) {
                return;
            }
            long currentTime = System.currentTimeMillis();
            if (currentTime - prevElectionTime < electionTime) {
                return;
            }
            log.warn("Node {}: Take part in election,new term:{},last entry:{}", peerConfig.getId(), persistentState.getCurrentTerm() + 1, persistentState.getLogModule().getLast());
            //更新上一次选举时间
            prevElectionTime = System.currentTimeMillis();
            //更新election timeout
            electionTime = 750 + ThreadLocalRandom.current().nextInt(750);
            //任期号自增
            persistentState.setCurrentTerm(persistentState.getCurrentTerm() + 1);
            //转换为Candidate状态
            nodeStatus = NodeStatusEnum.CANDIDATE;
            //投票给自己
            persistentState.setVotedFor(peerConfig.getId());
            //请求其他节点给自己投票
            List<Peer> peers = peerGroup.getOtherPeers(peerConfig);
            ArrayList<Future<VoteResponse>> futureArrayList = new ArrayList<>();
            for (Peer peer : peers) {
                futureArrayList.add(raftThreadPoolExecutor.submit(() -> {
                    long lastTerm = 0L;
                    LogEntry last = persistentState.getLogModule().getLast();
                    if (last != null) {
                        lastTerm = last.getTerm();
                    }
                    Request request = VoteRequest.builder()
                            .term(persistentState.getCurrentTerm())
                            .candidateId(peerConfig.getId())
                            .lastLogIndex(persistentState.getLogModule().getLastIndex())
                            .lastLogTerm(lastTerm)
                            .build();
                    request.setCmd(Request.REQUEST_VOTE);

                    try {
                        log.info("Node {}: Call request vote RPC, destination is Node {}", peerConfig.getId(), peer.getId());
                        return (VoteResponse) getRaftRpcClient().send(request, peer.getAddr());
                    } catch (RemotingException e) {
                        log.error("Node {}: RequestVote RPC failed, destination is Node {}", peerConfig.getId(), peer.getId());
                    }
                    return null;
                }));
            }
            AtomicInteger electionRes = new AtomicInteger(0);
            CountDownLatch countDownLatch = new CountDownLatch(futureArrayList.size());

            for (Future future : futureArrayList) {
                raftThreadPoolExecutor.submit(() -> {
                    try {
                        VoteResponse response = (VoteResponse) future.get(500, TimeUnit.MILLISECONDS);
                        if (response == null) {
                            return;
                        }
                        boolean voteGranted = response.isVoteGranted();
                        if (voteGranted) {
                            electionRes.incrementAndGet();
                        } else {
                            long newTerm = response.getTerm();
                            if (newTerm > persistentState.getCurrentTerm()) {
                                persistentState.setCurrentTerm(newTerm);
                            }
                            return;
                        }
                    } catch (Exception e) {
                        log.error("Node {}: Get vote response failed", peerConfig.getId());

                    } finally {
                        countDownLatch.countDown();
                    }
                });
            }
            try {
                // 等待线程池中的线程将结果收集好之后再处理
                countDownLatch.await(750, MILLISECONDS);
            } catch (InterruptedException e) {
                log.warn("Node {}: InterruptedException By Master election Task", peerConfig.getId());
            }
            int finalElectionRes = electionRes.get();
            log.info("Node {}: Send {} vote requests,and {} nodes vote for me", peerConfig.getId(), futureArrayList.size(), finalElectionRes);
            //检测当前节点是否依然保持着Candidate的状态
            if (nodeStatus == NodeStatusEnum.FOLLOWER) {
                log.warn("Node {}: Transitions to follower due to receipt of a higher term datagram", peerConfig.getId());
                return;
            }
            if (finalElectionRes >= peers.size() / 2) {
                log.warn("Node {}: Become leader", peerConfig.getId());
                nodeStatus = NodeStatusEnum.LEADER;
                peerGroup.setLeader(peerConfig);
                //清除投票状态
                persistentState.setVotedFor(0);
                doThingsAfterBeingLeader();
            } else {
                //如果未收到半数的票，重新开始下一次选举
                persistentState.setVotedFor(0);
            }

        }
    }

    /**
     * 成为leader以后需要做的工作：
     * 初始化matchIndex和nextIndex
     */
    private void doThingsAfterBeingLeader() {
        Map<Peer, Long> matchIndex = new ConcurrentHashMap<>();
        Map<Peer, Long> nextIndex = new ConcurrentHashMap<>();
        for (Peer peer : peerGroup.getOtherPeers(peerConfig)) {
            nextIndex.put(peer, persistentState.getLogModule().getLastIndex() + 1);
            matchIndex.put(peer, 0L);
        }
        leaderState.setMatchIndex(matchIndex);
        leaderState.setNextIndex(nextIndex);
    }

    /**
     * 向其他节点发送心跳
     */
    class HeartBeatTask implements Runnable {

        @Override
        public void run() {
            //如果当前节点非leader，则不向其他节点发送心跳
            if (nodeStatus != NodeStatusEnum.LEADER) {
                return;
            }
            long currentTime = System.currentTimeMillis();
            //如果未超出一定时间，则不向其他节点发送心跳
            if (currentTime - prevHeartBeatTime < heartBeatTick) {
                return;
            }
            //即将发送心跳，更新prevHeartBeatTime
            prevHeartBeatTime = System.currentTimeMillis();
            //向系统中的其他所有节点发送心跳
            for (Peer peer : peerGroup.getOtherPeers(peerConfig)) {
                Request request = AppendEntriesRequest.builder()
                        .entries(null)
                        .leaderId(peerConfig.getId())
                        .term(persistentState.getCurrentTerm())
                        .build();
                request.setCmd(Request.APPENDE_ENTRIES);
                raftThreadPoolExecutor.execute(() -> {
                    AppendEntriesResponse response = null;
                    try {
                        response = (AppendEntriesResponse) getRaftRpcClient().send(request, peer.getAddr());
                        long term = response.getTerm();
                        if (term > persistentState.getCurrentTerm()) {
                            log.error("Node {}: There comes a bigger term. Self node will become follower. Max term:{},my term:{} ", peerConfig.getId(), term, persistentState.getCurrentTerm());
                            persistentState.setCurrentTerm(term);
                            persistentState.setVotedFor(0);
                            nodeStatus = NodeStatusEnum.FOLLOWER;
                        }
                    } catch (RemotingException e) {
                        log.error("Node {}: HeartBeat RPC failed. Destination is {}", peerConfig.getId(), peer.getId());
                    }

                });
            }
        }

    }

    /**
     * 将commitChannel中的logEntry应用到状态机中，并将结果写回到applyChannel
     */
    class Apply2StateMachineTask implements Runnable {

        @Override
        public void run() {
            while (true) {
                try {
                    LogEntry entry = commitChannel.take();
                    Object res = stateMachine.apply(entry);
                    volatileState.setLastApplied(volatileState.getLastApplied() + 1);
                    log.info("Node {}: {} has been applied to stateMachine", peerConfig.getId(), entry);
                    if (nodeStatus == NodeStatusEnum.LEADER) {
                        applyChannel.offer(res);
                    }
                } catch (InterruptedException e) {
                    log.error("Interrupted Exception");
                }
            }
        }
    }
}
